833a7d
@@ -24,6 +24,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
 import java.util.Stack;
 
 import org.apache.hadoop.hive.conf.HiveConf;
@@ -53,6 +54,7 @@
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnListDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc.ExprNodeDescEqualityWrapper;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicListDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeFieldDesc;
@@ -64,6 +66,7 @@
 import org.apache.hadoop.hive.ql.stats.StatsUtils;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualNS;
@@ -76,19 +79,24 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotNull;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFStruct;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 
 public class StatsRulesProcFactory {
 
   private static final Logger LOG = LoggerFactory.getLogger(StatsRulesProcFactory.class.getName());
   private static final boolean isDebugEnabled = LOG.isDebugEnabled();
 
+
   /**
    * Collect basic statistics like number of rows, data size and column level statistics from the
    * table. Also sets the state of the available statistics. Basic and column statistics can have
@@ -299,7 +307,7 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
 
     private long evaluateExpression(Statistics stats, ExprNodeDesc pred,
         AnnotateStatsProcCtx aspCtx, List<String> neededCols,
-        FilterOperator fop, long evaluatedRowCount) throws CloneNotSupportedException {
+        FilterOperator fop, long evaluatedRowCount) throws CloneNotSupportedException, SemanticException {
       long newNumRows = 0;
       Statistics andStats = null;
 
@@ -338,6 +346,9 @@
private long evaluateExpression(Statistics stats, ExprNodeDesc pred,
               evaluatedRowCount = newNumRows;
             }
           }
+        } else if (udf instanceof GenericUDFIn) {
+          // for IN clause
+          newNumRows = evaluateInExpr(stats, pred, aspCtx, neededCols, fop);
         } else if (udf instanceof GenericUDFOPNot) {
           newNumRows = evaluateNotExpr(stats, pred, aspCtx, neededCols, fop);
         } else if (udf instanceof GenericUDFOPNotNull) {
@@ -375,9 +386,97 @@
private long evaluateExpression(Statistics stats, ExprNodeDesc pred,
       return newNumRows;
     }
 
+    /**
+     * Estimates the number of rows passing an IN predicate: {@code col IN (c1, ...)} or
+     * {@code struct(col1, ...) IN (const struct(...), ...)}. Each column contributes the ratio of
+     * distinct IN constants to its distinct-value count, capped at 1, or 0.5 when that is unknown.
+     *
+     * @param stats statistics of the rows reaching the filter
+     * @param pred the IN predicate; cast to {@link ExprNodeGenericFuncDesc}
+     * @param aspCtx annotation context (not read by this method)
+     * @param neededCols needed column names; when non-null, a column missing from it aborts analysis
+     * @param fop the filter operator being annotated (not read by this method)
+     * @return estimated row count; {@code numRows / 2} when the predicate cannot be analyzed
+     * @throws SemanticException declared by the signature; no statement here throws it directly
+     */
+    private long evaluateInExpr(Statistics stats, ExprNodeDesc pred, AnnotateStatsProcCtx aspCtx,
+            List<String> neededCols, FilterOperator fop) throws SemanticException {
+
+      long numRows = stats.getNumRows();
+
+      ExprNodeGenericFuncDesc fd = (ExprNodeGenericFuncDesc) pred;
+
+      // 1. It is an IN operator, check if it uses STRUCT
+      List<ExprNodeDesc> children = fd.getChildren();
+      List<ColStatistics> columnStats = Lists.newArrayList();
+      List<Set<ExprNodeDescEqualityWrapper>> values = Lists.newArrayList();
+      ExprNodeDesc columnsChild = children.get(0);
+      boolean multiColumn = columnsChild instanceof ExprNodeGenericFuncDesc &&
+              ((ExprNodeGenericFuncDesc) columnsChild).getGenericUDF() instanceof GenericUDFStruct;
+      List<ExprNodeDesc> columns = multiColumn
+              ? columnsChild.getChildren() : Lists.<ExprNodeDesc>newArrayList(columnsChild);
+      for (ExprNodeDesc columnChild : columns) {
+        // If column is not a column reference, we bail out
+        if (!(columnChild instanceof ExprNodeColumnDesc)) {
+          // Default
+          return numRows / 2;
+        }
+        final String columnName = ((ExprNodeColumnDesc) columnChild).getColumn();
+        // if column name is not contained in needed column list then it
+        // is a partition column. We do not need to evaluate partition columns
+        // in filter expression since it will be taken care by partition pruner
+        if (neededCols != null && !neededCols.contains(columnName)) {
+          // Default
+          return numRows / 2;
+        }
+        columnStats.add(stats.getColumnStatisticsFromColName(columnName));
+        values.add(Sets.<ExprNodeDescEqualityWrapper>newHashSet());
+      }
+
+      // 2. Extract the distinct constant values listed for each column
+      for (int i = 1; i < children.size(); i++) {
+        ExprNodeDesc child = children.get(i);
+        // If value is not a constant, we bail out
+        if (!(child instanceof ExprNodeConstantDesc)) {
+          // Default
+          return numRows / 2;
+        }
+        if (multiColumn) {
+          ExprNodeConstantDesc constantChild = (ExprNodeConstantDesc) child;
+          Object writable = constantChild.getWritableObjectInspector().getWritableConstantValue();
+          // Bail out unless this is a struct constant carrying exactly one field per IN column
+          if (!(constantChild.getTypeInfo() instanceof StructTypeInfo) || !(writable instanceof List)) {
+            return numRows / 2;
+          }
+          List<?> items = (List<?>) writable;
+          List<TypeInfo> structTypes = ((StructTypeInfo) constantChild.getTypeInfo()).getAllStructFieldTypeInfos();
+          if (structTypes.size() != values.size() || items.size() != values.size()) {
+            return numRows / 2;
+          }
+          for (int j = 0; j < structTypes.size(); j++) {
+            ExprNodeConstantDesc constant = new ExprNodeConstantDesc(structTypes.get(j), items.get(j));
+            values.get(j).add(new ExprNodeDescEqualityWrapper(constant));
+          }
+        } else {
+          values.get(0).add(new ExprNodeDescEqualityWrapper(child));
+        }
+      }
+
+      // 3. Calculate IN selectivity
+      double factor = 1d;
+      for (int i = 0; i < columnStats.size(); i++) {
+        long dvs = columnStats.get(i) == null ? 0 : columnStats.get(i).getCountDistint();
+        // (num of distinct vals for col in IN clause / num of distinct vals for col); a filter
+        // can never keep more rows than it receives, so each column's selectivity is capped at 1
+        double columnFactor = dvs <= 0 ? 0.5d : Math.min(1d, (double) values.get(i).size() / dvs);
+        factor *= columnFactor;
+      }
+      return Math.round(numRows * factor);
+    }
+
     private long evaluateNotExpr(Statistics stats, ExprNodeDesc pred,
         AnnotateStatsProcCtx aspCtx, List<String> neededCols, FilterOperator fop)
-        throws CloneNotSupportedException {
+        throws CloneNotSupportedException, SemanticException {
 
       long numRows = stats.getNumRows();
 
@@ -676,7 +775,7 @@
private long evaluateComparator(Statistics stats, ExprNodeGenericFuncDesc genFun
 
     private long evaluateChildExpr(Statistics stats, ExprNodeDesc child,
         AnnotateStatsProcCtx aspCtx, List<String> neededCols,
-        FilterOperator fop, long evaluatedRowCount) throws CloneNotSupportedException {
+        FilterOperator fop, long evaluatedRowCount) throws CloneNotSupportedException, SemanticException {
 
       long numRows = stats.getNumRows();
 
@@ -761,7 +860,7 @@
private long evaluateChildExpr(Statistics stats, ExprNodeDesc child,
         } else if (udf instanceof GenericUDFOPNull) {
           return evaluateColEqualsNullExpr(stats, genFunc);
         } else if (udf instanceof GenericUDFOPAnd || udf instanceof GenericUDFOPOr
-            || udf instanceof GenericUDFOPNot) {
+            || udf instanceof GenericUDFIn || udf instanceof GenericUDFOPNot) {
           return evaluateExpression(stats, genFunc, aspCtx, neededCols, fop, evaluatedRowCount);
         }
       }
